package li.cil.oc.common.asm

import cpw.mods.fml.common.asm.transformers.deobf.FMLDeobfuscatingRemapper
import cpw.mods.fml.relauncher.IFMLLoadingPlugin.TransformerExclusions
import java.util.logging.{Level, Logger}
import li.cil.oc.common.asm.template.SimpleComponentImpl
import li.cil.oc.util.mods.Mods
import net.minecraft.launchwrapper.{LaunchClassLoader, IClassTransformer}
import org.objectweb.asm.tree._
import org.objectweb.asm.{Opcodes, ClassWriter, ClassReader}
import scala.collection.convert.WrapAsJava._
import scala.collection.convert.WrapAsScala._

/**
 * Bytecode transformer for OpenComputers.
 *
 * Performs three kinds of rewrites on classes as they are loaded:
 *  - removes the StargateTech2 `IBusDevice` interface from the computer and
 *    rack tile entities when `Mods.StargateTech2` is not available;
 *  - strips references to missing classes from OpenComputers' own classes
 *    (interfaces, inner class entries whose outer type is missing, and methods
 *    whose descriptor mentions a missing type);
 *  - injects the `Environment` implementation from the `SimpleEnvironment`
 *    template into tile entities implementing
 *    `li.cil.oc.api.network.SimpleComponent`.
 *
 * If anything throws, the original class bytes are returned unchanged.
 */
@TransformerExclusions(Array("li.cil.oc.common.asm"))
class ClassTransformer extends IClassTransformer {
  val loader = classOf[ClassTransformer].getClassLoader.asInstanceOf[LaunchClassLoader]

  val log = Logger.getLogger("OpenComputers")

  /** Matches object types (`Lsome/Type;`) in a JVM method descriptor; group 1 is the internal name. */
  private val ObjectTypeInDescriptor = """L([^;]+);""".r

  override def transform(name: String, transformedName: String, basicClass: Array[Byte]): Array[Byte] = {
    // No bytes means there is nothing to transform; hand it back untouched
    // instead of failing inside ClassReader and logging a bogus warning.
    if (basicClass == null) basicClass
    else try {
      val sgtCompatible =
        if (name == "li.cil.oc.common.tileentity.traits.Computer" || name == "li.cil.oc.common.tileentity.Rack")
          ensureStargateTechCompatibility(basicClass)
        else basicClass
      if (isExcludedFromTransformation(name)) sgtCompatible
      else {
        val stripped =
          if (name.startsWith("""li.cil.oc.""")) stripMissingReferences(name, sgtCompatible)
          else sgtCompatible
        injectSimpleComponentIfTagged(name, stripped)
      }
    }
    catch {
      // Deliberately broad: a transformer must never break class loading, so
      // any failure falls back to the untransformed class.
      case t: Throwable =>
        log.log(Level.WARNING, "Something went wrong!", t)
        basicClass
    }
  }

  /** Packages that the stripping and component-injection passes never touch. */
  private def isExcludedFromTransformation(name: String): Boolean =
    name.startsWith("""net.minecraft.""") ||
      name.startsWith("""net.minecraftforge.""") ||
      name.startsWith("""li.cil.oc.common.asm.""")

  /**
   * Strips foreign interfaces, inner class entries and methods that reference
   * classes which cannot be found. This is primarily intended to clean up
   * mix-ins / synthetic classes generated by Scala.
   *
   * @param name       class name, used for logging only.
   * @param classBytes the class to clean up.
   * @return the rewritten class, or `classBytes` itself if nothing was removed.
   */
  private def stripMissingReferences(name: String, classBytes: Array[Byte]): Array[Byte] = {
    val classNode = newClassNode(classBytes)

    val missingInterfaces = classNode.interfaces.filter(!classExists(_))
    for (interfaceName <- missingInterfaces) {
      log.fine(s"Stripping interface $interfaceName from class $name because it is missing.")
    }
    classNode.interfaces.removeAll(missingInterfaces)

    val missingClasses = classNode.innerClasses.filter(inner => inner.outerName != null && !classExists(inner.outerName))
    for (innerClass <- missingClasses) {
      log.fine(s"Stripping inner class ${innerClass.name} from class $name because its type ${innerClass.outerName} is missing.")
    }
    classNode.innerClasses.removeAll(missingClasses)

    // Resolve each descriptor once; the list of missing types is needed for the log message too.
    val incompleteMethods = classNode.methods.
      map(method => (method, missingFromSignature(method.desc).toList)).
      filter { case (_, missing) => missing.nonEmpty }
    for ((method, missing) <- incompleteMethods) {
      log.fine(s"Stripping method ${method.name} from class $name because the following types in its signature are missing: ${missing.mkString(", ")}")
    }
    classNode.methods.removeAll(incompleteMethods.map(_._1))

    // Only pay for re-serialization when something actually changed.
    if (missingInterfaces.isEmpty && missingClasses.isEmpty && incompleteMethods.isEmpty) classBytes
    else writeClass(classNode)
  }

  /**
   * Injects the component logic if the class directly lists
   * `li/cil/oc/api/network/SimpleComponent` among its interfaces.
   *
   * @return the injected class, or `classBytes` unchanged if the class is not
   *         tagged or injection failed (failures are logged as warnings).
   */
  private def injectSimpleComponentIfTagged(name: String, classBytes: Array[Byte]): Array[Byte] = {
    // This runs for nearly every loaded class, so only read the class header
    // here and build the full tree just for tagged classes.
    if (!new ClassReader(classBytes).getInterfaces.contains("li/cil/oc/api/network/SimpleComponent")) classBytes
    else {
      val classNode = newClassNode(classBytes)
      try {
        val injected = injectEnvironmentImplementation(classNode)
        log.info(s"Successfully injected component logic into class $name.")
        injected
      }
      catch {
        case e: Throwable =>
          log.log(Level.WARNING, s"Failed injecting component logic into class $name.", e)
          classBytes
      }
    }
  }

  /**
   * Checks whether a class can be found. The name may be given in internal
   * (`a/b/C`) or binary (`a.b.C`) form. Lookups are tried in this order: raw
   * class bytes via the launch class loader, the same with the obfuscated
   * (unmapped) name, then `findClass`.
   * NOTE(review): the `findClass` fallback may define the class (and so run
   * transformers on it) as a side effect.
   */
  private def classExists(name: String): Boolean = {
    loader.getClassBytes(name) != null ||
      loader.getClassBytes(FMLDeobfuscatingRemapper.INSTANCE.unmap(name)) != null ||
      (try loader.findClass(name.replace('/', '.')) != null catch {
        case _: ClassNotFoundException => false
      })
  }

  /** Internal names of all object types in the method descriptor `desc` that cannot be found. */
  private def missingFromSignature(desc: String): Iterator[String] = {
    ObjectTypeInDescriptor.findAllMatchIn(desc).map(_.group(1)).filter(!classExists(_))
  }

  /**
   * Removes `stargatetech2/api/bus/IBusDevice` from the class's interfaces
   * when StargateTech2 is not available.
   *
   * @return the rewritten class, or `basicClass` if the mod is available or
   *         the interface was not declared.
   */
  def ensureStargateTechCompatibility(basicClass: Array[Byte]): Array[Byte] = {
    if (Mods.StargateTech2.isAvailable) basicClass
    else {
      // No SGT2 or version is too old, abstract bus API doesn't exist.
      val classNode = newClassNode(basicClass)
      if (classNode.interfaces.remove("stargatetech2/api/bus/IBusDevice")) writeClass(classNode)
      else basicClass
    }
  }

  /**
   * Injects the `Environment` methods and wraps the required tile entity
   * methods using the `li/cil/oc/common/asm/template/SimpleEnvironment`
   * template. It also adds the `SimpleComponentImpl` interface.
   *
   * @param classNode a class declaring the `SimpleComponent` interface; it is mutated in place.
   * @return the serialized class, written with COMPUTE_MAXS | COMPUTE_FRAMES.
   * @throws InjectionFailedException if the class is not a tile entity, the
   *         template is missing, `node()` is already defined, a delegator
   *         name is taken, or a wrapped method is final in a superclass.
   */
  def injectEnvironmentImplementation(classNode: ClassNode): Array[Byte] = {
    log.fine(s"Injecting methods from Environment interface into ${classNode.name}.")
    if (!isTileEntity(classNode)) {
      throw new InjectionFailedException("Found SimpleComponent on something that isn't a tile entity, ignoring.")
    }

    val template = classNodeFor("li/cil/oc/common/asm/template/SimpleEnvironment")
    if (template == null) {
      throw new InjectionFailedException("Could not find SimpleComponent template!")
    }

    // Copies a method from the template unless the class already defines it.
    // With `required`, an existing definition is an error instead of an override.
    def inject(methodName: String, signature: String, required: Boolean = false): Unit = {
      def filter(method: MethodNode) = method.name == methodName && method.desc == signature
      if (classNode.methods.exists(filter)) {
        if (required) {
          throw new InjectionFailedException(s"Could not inject method '$methodName$signature' because it was already present!")
        }
      }
      else template.methods.find(filter) match {
        case Some(method) => classNode.methods.add(method)
        case _ => throw new AssertionError(s"Couldn't find '$methodName$signature' in template implementation.")
      }
    }
    inject("node", "()Lli/cil/oc/api/network/Node;", required = true)
    inject("onConnect", "(Lli/cil/oc/api/network/Node;)V")
    inject("onDisconnect", "(Lli/cil/oc/api/network/Node;)V")
    inject("onMessage", "(Lli/cil/oc/api/network/Message;)V")

    log.fine("Injecting / wrapping overrides for required tile entity methods.")
    // Wraps a TileEntity method. If the class already implements it, that
    // implementation is renamed to `<name><PostFix>`. Otherwise the template's
    // `<name><PostFix>` is copied in, after checking the method is not final in
    // any superclass. Then the template's `<name>` is added on top.
    // Methods are matched by plain (MCP) name or SRG name, with descriptors
    // mapped through the deobfuscation remapper.
    def replace(methodName: String, methodNameSrg: String, desc: String): Unit = {
      val mapper = FMLDeobfuscatingRemapper.INSTANCE
      def filter(method: MethodNode) = {
        val descDeObf = mapper.mapMethodDesc(method.desc)
        val methodNameDeObf = mapper.mapMethodName(tileEntityNameObfed, method.name, method.desc)
        val areSamePlain = method.name + descDeObf == methodName + desc
        val areSameDeObf = methodNameDeObf + descDeObf == methodNameSrg + desc
        areSamePlain || areSameDeObf
      }
      if (classNode.methods.exists(method => method.name == methodName + SimpleComponentImpl.PostFix && mapper.mapMethodDesc(method.desc) == desc)) {
        throw new InjectionFailedException(s"Delegator method name '${methodName + SimpleComponentImpl.PostFix}' is already in use.")
      }
      classNode.methods.find(filter) match {
        case Some(method) =>
          log.fine(s"Found original implementation of '$methodName', wrapping.")
          method.name = methodName + SimpleComponentImpl.PostFix
        case _ =>
          log.fine(s"No original implementation of '$methodName', will inject override.")
          // Walks the superclass chain; overriding a final method would fail verification.
          def ensureNonFinalIn(name: String): Unit = {
            if (name != null) {
              val node = classNodeFor(name)
              if (node != null) {
                node.methods.find(filter) match {
                  case Some(method) =>
                    if ((method.access & Opcodes.ACC_FINAL) != 0) {
                      throw new InjectionFailedException(s"Method '$methodName' is final in superclass ${node.name.replace('/', '.')}.")
                    }
                  case _ =>
                }
                ensureNonFinalIn(node.superName)
              }
            }
          }
          ensureNonFinalIn(classNode.superName)
          template.methods.find(_.name == methodName + SimpleComponentImpl.PostFix) match {
            case Some(method) => classNode.methods.add(method)
            case _ => throw new AssertionError(s"Couldn't find '${methodName + SimpleComponentImpl.PostFix}' in template implementation.")
          }
      }
      template.methods.find(filter) match {
        case Some(method) => classNode.methods.add(method)
        case _ => throw new AssertionError(s"Couldn't find '$methodName' in template implementation.")
      }
    }
    replace("validate", "func_70312_q", "()V")
    replace("invalidate", "func_70313_j", "()V")
    replace("onChunkUnload", "func_76631_c", "()V")
    replace("readFromNBT", "func_70307_a", "(Lnet/minecraft/nbt/NBTTagCompound;)V")
    replace("writeToNBT", "func_70310_b", "(Lnet/minecraft/nbt/NBTTagCompound;)V")

    log.fine("Injecting interface.")
    classNode.interfaces.add("li/cil/oc/common/asm/template/SimpleComponentImpl")

    // NOTE(review): COMPUTE_FRAMES makes ASM call ClassWriter.getCommonSuperClass,
    // whose default implementation loads classes through ASM's own class loader.
    // That loader may not see mod classes; such failures propagate to the caller.
    writeClass(classNode, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES)
  }

  val tileEntityNamePlain = "net/minecraft/tileentity/TileEntity"
  val tileEntityNameObfed = FMLDeobfuscatingRemapper.INSTANCE.unmap(tileEntityNamePlain)

  /**
   * Whether `classNode` is `TileEntity` (plain or obfuscated name) or has it
   * somewhere up its superclass chain. Returns false for `null` or for a chain
   * that reaches a class whose bytes cannot be found.
   */
  def isTileEntity(classNode: ClassNode): Boolean = {
    if (classNode == null) false
    else {
      log.finer(s"Checking if class ${classNode.name} is a TileEntity...")
      classNode.name == tileEntityNamePlain || classNode.name == tileEntityNameObfed ||
        (classNode.superName != null && isTileEntity(classNodeFor(classNode.superName)))
    }
  }

  /**
   * Parses the class with the given internal name, first by its plain name
   * and then by its obfuscated (unmapped) name.
   *
   * @return the parsed class, or `null` if the launch class loader has no bytes for either name.
   */
  def classNodeFor(name: String): ClassNode = {
    val namePlain = name.replace('/', '.')
    val bytes = loader.getClassBytes(namePlain)
    if (bytes != null) newClassNode(bytes)
    else {
      val nameObfed = FMLDeobfuscatingRemapper.INSTANCE.unmap(name).replace('/', '.')
      val bytes = loader.getClassBytes(nameObfed)
      if (bytes == null) null
      else newClassNode(bytes)
    }
  }

  /** Parses raw class bytes into a fresh tree representation (no reader flags). */
  def newClassNode(data: Array[Byte]): ClassNode = {
    val classNode = new ClassNode()
    new ClassReader(data).accept(classNode, 0)
    classNode
  }

  /** Serializes a class tree with the given `ClassWriter` flags (default: COMPUTE_MAXS). */
  def writeClass(classNode: ClassNode, flags: Int = ClassWriter.COMPUTE_MAXS): Array[Byte] = {
    val writer = new ClassWriter(flags)
    classNode.accept(writer)
    writer.toByteArray
  }

  /** Signals that component injection was refused or is impossible for a class. */
  class InjectionFailedException(message: String) extends Exception(message)

}
